import torch

# Build the example inputs.
# feats: [batch_size=2, sequence_length=3, num_labels=4], holding 1.0 .. 24.0 in row-major order.
feats = torch.arange(1.0, 25.0).reshape(2, 3, 4)

# input_mask_weight: [batch_size=2, sequence_length=3, 1].
# Position 0 of every sequence gets weight 1; all later positions get weight 0.
input_mask_weight = torch.zeros(2, 3, 1)
input_mask_weight[:, 0, :] = 1.0

# Element-wise product. The mask's trailing size-1 dimension broadcasts across
# all num_labels columns, so masked positions become all-zero rows.
result = torch.mul(feats, input_mask_weight)

# Report the three shapes first, then dump the tensors themselves.
for _caption, _tensor in (
    ("feats shape:", feats),
    ("input_mask_weight shape:", input_mask_weight),
    ("result shape:", result),
):
    print(_caption, _tensor.shape)

for _tensor in (feats, input_mask_weight, result):
    print(_tensor)